{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import time\n",
    "import sys\n",
    "if \"../\" not in sys.path:\n",
    "  sys.path.append(\"../\") \n",
    "from lib.utils import read_data_from_file, sign, sigmoid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def logistic_reg(X, y, eta=0.1, epsilon=1e-6, updates=5000):\n",
    "    \"\"\"\n",
    "    Logistic Regression Algorithm\n",
    "    Args:\n",
    "        X: 数据\n",
    "        Y: 标签\n",
    "        eta: 步长\n",
    "        epsilon: 误差\n",
    "        updates: 迭代次数\n",
    "    \n",
    "    Returns:\n",
    "        w: 特征权重\n",
    "        gradient: 梯度 \n",
    "    \"\"\"\n",
    "    w = np.zeros_like(X[0])\n",
    "    gradient = np.zeros_like(w)\n",
    "    \n",
    "    for i in range(1, updates + 1):\n",
    "        if i%10 == 0:\n",
    "            sys.stdout.flush()\n",
    "            print('\\repisode: {}'.format(i), end='')\n",
    "        \n",
    "        # step1: cal gradient\n",
    "        gradient = np.mean((sigmoid(X.dot(w) * (-y))* -y).reshape((-1, 1))*X, axis=0)\n",
    "        \n",
    "        # step2: update weight\n",
    "        w = w - eta*gradient\n",
    "        \n",
    "        if np.linalg.norm(gradient) <= epsilon:\n",
    "            break\n",
    "    return w, gradient"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def logistic_reg_sgd(X, y, eta=0.05, epsilon=1e-6, updates=20000):\n",
    "    \"\"\"\n",
    "    Logistic Regression Stochastic Gradient Descent Algorithm\n",
    "    Args:\n",
    "        X: 数据\n",
    "        Y: 标签\n",
    "        eta: 步长\n",
    "        epsilon: 误差\n",
    "        updates: 迭代次数\n",
    "    \n",
    "    Returns:\n",
    "        w: 特征权重\n",
    "        gradient: 梯度\n",
    "    \"\"\"    \n",
    "    w = np.zeros_like(X[0])\n",
    "    gradient = np.zeros_like(w)\n",
    "    \n",
    "    for i in range(1, updates + 1):\n",
    "        if i%50 == 0:\n",
    "            sys.stdout.flush()\n",
    "            print('\\repisode: {}\\t'.format(i), end='')\n",
    "        \n",
    "        # step1: random pick up one example\n",
    "        index = np.random.choice(range(len(X)))\n",
    "        \n",
    "        # step1: cal sgd\n",
    "        gradient = (sigmoid(X[index].dot(w) * (-y[index]))* -y[index])*X[index]\n",
    "        \n",
    "        # step2: update weight\n",
    "        w = w - eta*gradient\n",
    "        \n",
    "        if np.linalg.norm(gradient) <= epsilon or i>=updates:\n",
    "            break\n",
    "    return w, gradient"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data_train shape:  (500, 5)\n",
      "data_test shape:  (500, 5)\n"
     ]
    }
   ],
   "source": [
    "# 数据读取\n",
    "data_train = read_data_from_file('hw1_18_train.dat') \n",
    "print('data_train shape: ', data_train.shape)\n",
    "data_test = read_data_from_file('hw1_18_test.dat')\n",
    "print('data_test shape: ', data_test.shape)\n",
    "y = data_train[:,-1]\n",
    "X = np.concatenate((np.ones((data_train.shape[0],1)), data_train[:,:-1]), axis=1)\n",
    "y_test = data_test[:,-1]\n",
    "X_test = np.concatenate((np.ones((data_test.shape[0],1)), data_test[:,:-1]), axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_err_rate(X, y, w):\n",
    "    y_hat = sign(X.dot(w))\n",
    "    err_rate = (y_hat != y).mean()\n",
    "    return err_rate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# logistic regression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "episode: 5000\n",
      "logistic reg for classification:\n",
      "error rate Ein: 0.1\n",
      "error rate Eout: 0.1\n"
     ]
    }
   ],
   "source": [
    "# logistic reg\n",
    "w_reg, gradient = logistic_reg(X, y)\n",
    "print()\n",
    "print('logistic reg for classification:')\n",
    "\n",
    "err_rates = []\n",
    "err_rates.append(get_err_rate(X, y, w_reg))\n",
    "print(\"error rate Ein: {}\".format(np.mean(np.array(err_rates))))\n",
    "\n",
    "err_rates = []\n",
    "err_rates.append(get_err_rate(X_test, y_test, w_reg))\n",
    "print(\"error rate Eout: {}\".format(np.mean(np.array(err_rates))))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SGD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "episode: 20000\t\n",
      "logistic reg sgd for classification:\n",
      "error rate Ein: 0.114\n",
      "error rate Eout: 0.122\n"
     ]
    }
   ],
   "source": [
    "w_reg_sgd, gradient = logistic_reg_sgd(X, y)\n",
    "print()\n",
    "print('logistic reg sgd for classification:')\n",
    "\n",
    "err_rates = []\n",
    "err_rates.append(get_err_rate(X, y, w_reg_sgd))\n",
    "print(\"error rate Ein: {}\".format(np.mean(np.array(err_rates))))\n",
    "\n",
    "err_rates = []\n",
    "err_rates.append(get_err_rate(X_test, y_test, w_reg_sgd))\n",
    "print(\"error rate Eout: {}\".format(np.mean(np.array(err_rates))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": false,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
